{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T13:39:09.699430Z",
     "start_time": "2019-12-01T13:39:09.108012Z"
    }
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from art.attacks.fast_gradient import FastGradientMethod\n",
    "from art.classifiers.pytorch import PyTorchClassifier\n",
    "from art.utils import load_mnist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T13:46:02.295675Z",
     "start_time": "2019-12-01T13:46:02.279475Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step 0: Define the neural network model, return logits instead of activation in forward method\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Convolutional classifier: two conv/ReLU/max-pool stages, then two linear layers.\n",
    "\n",
    "    ``forward`` returns the raw output of the last linear layer (10 values per sample);\n",
    "    no softmax or other activation is applied to it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Stage 1: (1, 28, 28) -> conv -> (16, 28, 28) -> pool -> (16, 14, 14).\n",
    "        # Conv2d arguments: in_channels = number of input channels, out_channels = number\n",
    "        # of convolution kernels, kernel_size = kernel size, stride = step.\n",
    "        # With stride=1, padding=(kernel_size - 1) / 2 leaves height and width unchanged.\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2),\n",
    "        )\n",
    "        # Stage 2: (16, 14, 14) -> conv -> (32, 14, 14) -> pool -> (32, 7, 7).\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2),\n",
    "        )\n",
    "        # Hidden layer on the flattened 32*7*7 feature vector.\n",
    "        self.fc1 = nn.Sequential(nn.Linear(32 * 7 * 7, 500), nn.ReLU())\n",
    "        # Output layer: one logit per class.\n",
    "        self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Compute logits for a batch ``x``; shape comments assume 1x28x28 inputs.\"\"\"\n",
    "        feature_maps = self.conv2(self.conv1(x))  # (batch, 32, 7, 7)\n",
    "        flattened = feature_maps.view(feature_maps.size(0), -1)  # (batch, 32*7*7)\n",
    "        return self.fc2(self.fc1(flattened))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:10:33.063959Z",
     "start_time": "2019-12-01T14:10:32.179396Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step 1: Load the MNIST dataset\n",
    "\n",
    "# load_mnist() yields a (train, test) pair of (images, labels) tuples followed by the\n",
    "# minimum and maximum pixel value of the data.\n",
    "(\n",
    "    (x_train, y_train),\n",
    "    (x_test, y_test),\n",
    "    min_pixel_value,\n",
    "    max_pixel_value,\n",
    ") = load_mnist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:10:33.069789Z",
     "start_time": "2019-12-01T14:10:33.065990Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "min_pixel_value:  0.0\n",
      "max_pixel_value:  1.0\n"
     ]
    }
   ],
   "source": [
    "# Show the pixel range reported by load_mnist(); these two values are later handed to\n",
    "# the classifier as clip_values.\n",
    "for bound_label, bound_value in (\n",
    "    ('min_pixel_value: ', min_pixel_value),\n",
    "    ('max_pixel_value: ', max_pixel_value),\n",
    "):\n",
    "    print(bound_label, bound_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:10:33.316747Z",
     "start_time": "2019-12-01T14:10:33.071408Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step 1a: Swap axes to PyTorch's NCHW format\n",
    "\n",
    "# Per-sample layout the network is fed; same value as the input_shape given to the\n",
    "# ART classifier in Step 3.\n",
    "NCHW_IMAGE_SHAPE = (1, 28, 28)\n",
    "\n",
    "\n",
    "def to_nchw(images):\n",
    "    \"\"\"Return `images` as a float32 array in NCHW layout.\n",
    "\n",
    "    The axes are permuted with transpose(0, 3, 1, 2) unless the per-sample shape is\n",
    "    already NCHW_IMAGE_SHAPE. The guard makes this cell idempotent: because the result\n",
    "    is stored back into x_train / x_test, an unguarded transpose would permute the axes\n",
    "    again on every re-run and silently corrupt the data layout.\n",
    "    \"\"\"\n",
    "    if images.shape[1:] != NCHW_IMAGE_SHAPE:\n",
    "        images = images.transpose(0, 3, 1, 2)\n",
    "    return images.astype(np.float32)\n",
    "\n",
    "\n",
    "x_train = to_nchw(x_train)\n",
    "x_test = to_nchw(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:10:36.609872Z",
     "start_time": "2019-12-01T14:10:36.590949Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step 2: Create the model\n",
    "\n",
    "# Seed the RNGs before the weights are initialised so that a fresh\n",
    "# Restart & Run All gives the same model and the same accuracies.\n",
    "# torch drives the layer initialisation in Net(); Python's and NumPy's generators are\n",
    "# seeded as well because the training loop may shuffle with either one, depending on\n",
    "# the installed ART version (NOTE(review): confirm which one your version uses).\n",
    "# Bit-for-bit repeatability is only expected on CPU.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "model = Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:10:39.750870Z",
     "start_time": "2019-12-01T14:10:39.745151Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step 2a: Define the loss function and the optimizer\n",
    "\n",
    "# SGD with momentum over every parameter of the model created in Step 2.\n",
    "optimizer = optim.SGD(\n",
    "    model.parameters(),\n",
    "    lr=0.01,\n",
    "    momentum=0.5,\n",
    ")\n",
    "# Cross-entropy loss, applied to the logits returned by Net.forward.\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:10:49.389170Z",
     "start_time": "2019-12-01T14:10:49.359559Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step 3: Create the ART classifier\n",
    "\n",
    "# Wrap the PyTorch model, its loss and its optimizer in ART's classifier interface.\n",
    "# clip_values is the pixel range reported by load_mnist().\n",
    "classifier = PyTorchClassifier(\n",
    "    model=model,\n",
    "    loss=criterion,\n",
    "    optimizer=optimizer,\n",
    "    input_shape=(1, 28, 28),\n",
    "    nb_classes=10,\n",
    "    clip_values=(min_pixel_value, max_pixel_value),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:11:31.567038Z",
     "start_time": "2019-12-01T14:11:19.279537Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step 4: Train the ART classifier\n",
    "\n",
    "# Three passes over the training set in mini-batches of 64 samples.\n",
    "classifier.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    batch_size=64,\n",
    "    nb_epochs=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:11:38.629733Z",
     "start_time": "2019-12-01T14:11:38.484844Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy on benign test examples: 97.8%\n"
     ]
    }
   ],
   "source": [
    "# Step 5: Evaluate the ART classifier on benign test examples\n",
    "\n",
    "predictions = classifier.predict(x_test)\n",
    "# Reduce predictions and y_test to class indices (argmax over axis 1); the mean of the\n",
    "# element-wise matches is the fraction of correctly classified samples.\n",
    "accuracy = np.mean(\n",
    "    np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)\n",
    ")\n",
    "print('Accuracy on benign test examples: {}%'.format(accuracy * 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:12:45.236423Z",
     "start_time": "2019-12-01T14:11:48.625213Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step 6: Generate adversarial test examples\n",
    "\n",
    "# Fast Gradient Method against the trained classifier, eps=0.2 (the inputs span the\n",
    "# pixel range printed after loading the data).\n",
    "attack = FastGradientMethod(\n",
    "    classifier=classifier,\n",
    "    eps=0.2,\n",
    ")\n",
    "# Perturb the whole test set.\n",
    "x_test_adv = attack.generate(x=x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-01T14:12:45.342107Z",
     "start_time": "2019-12-01T14:12:45.238852Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy on adversarial test examples: 13.52%\n"
     ]
    }
   ],
   "source": [
    "# Step 7: Evaluate the ART classifier on adversarial test examples\n",
    "\n",
    "predictions = classifier.predict(x_test_adv)\n",
    "# Same metric as for the benign data: fraction of samples whose argmax prediction\n",
    "# equals the argmax of the corresponding y_test row.\n",
    "accuracy = np.mean(\n",
    "    np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)\n",
    ")\n",
    "print('Accuracy on adversarial test examples: {}%'.format(accuracy * 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
